/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * SimpleCart.java
 * Copyright (C) 2007 Haijian Shi
 *
 */

package weka.classifiers.trees;

import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.matrix.Matrix;

/**
 * <!-- globalinfo-start --> Class implementing minimal cost-complexity pruning.<br/>
 * Note when dealing with missing values, use "fractional instances" method
 * instead of surrogate split method.<br/>
 * <br/>
 * For more information, see:<br/>
 * <br/>
 * Leo Breiman, Jerome H. Friedman, Richard A. Olshen, Charles J. Stone (1984).
 * Classification and Regression Trees. Wadsworth International Group, Belmont,
 * California.
 * <p/>
 * <!-- globalinfo-end -->
 * 
 * <!-- technical-bibtex-start --> BibTeX:
 * 
 * <pre>
 * &#64;book{Breiman1984,
 *    address = {Belmont, California},
 *    author = {Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone},
 *    publisher = {Wadsworth International Group},
 *    title = {Classification and Regression Trees},
 *    year = {1984}
 * }
 * </pre>
 * <p/>
 * <!-- technical-bibtex-end -->
 * 
 * <!-- options-start --> Valid options are:
 * <p/>
 * 
 * <pre>
 * -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)
 * </pre>
 * 
 * <pre>
 * -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
 * </pre>
 * 
 * <pre>
 * -M &lt;min no&gt;
 *  The minimal number of instances at the terminal nodes.
 *  (default 2)
 * </pre>
 * 
 * <pre>
 * -N &lt;num folds&gt;
 *  The number of folds used in the minimal cost-complexity pruning.
 *  (default 5)
 * </pre>
 * 
 * <pre>
 * -U
 *  Don't use the minimal cost-complexity pruning.
 *  (default yes).
 * </pre>
 * 
 * <pre>
 * -H
 *  Don't use the heuristic method for binary split.
 *  (default true).
 * </pre>
 * 
 * <pre>
 * -A
 *  Use 1 SE rule to make pruning decision.
 *  (default no).
 * </pre>
 * 
 * <pre>
 * -C
 *  Percentage of training data size (0-1].
 *  (default 1).
 * </pre>
 * 
 * <!-- options-end -->
 * 
 * @author Haijian Shi (hs69@cs.waikato.ac.nz)
 * @version $Revision$
 */
public class SimpleCart extends RandomizableClassifier implements
  AdditionalMeasureProducer, TechnicalInformationHandler {

  /** For serialization. */
  private static final long serialVersionUID = 4154189200352566053L;

  /** Training data routed to this node (weighted copies of the instances). */
  protected Instances m_train;

  /** Successor nodes (two children when this node is not a leaf). */
  protected SimpleCart[] m_Successors;

  /** Attribute used to split data. */
  protected Attribute m_Attribute;

  /** Split point for a numeric attribute. */
  protected double m_SplitValue;

  /** Split subset used to split data for nominal attributes. */
  protected String m_SplitString;

  /** Class value if the node is leaf. */
  protected double m_ClassValue;

  /** Class attribute of data. */
  protected Attribute m_ClassAttribute;

  /** Minimum number of instances at the terminal nodes. */
  protected double m_minNumObj = 2;

  /** Number of folds for minimal cost-complexity pruning. */
  protected int m_numFoldsPruning = 5;

  /** Alpha-value (for pruning) at the node. */
  protected double m_Alpha;

  /** Number of training examples misclassified by the model (subtree rooted). */
  protected double m_numIncorrectModel;

  /**
   * Number of training examples misclassified by the model (subtree not
   * rooted).
   */
  protected double m_numIncorrectTree;

  /** Indicate if the node is a leaf node. */
  protected boolean m_isLeaf;

  /** If use minimal cost-complexity pruning. */
  protected boolean m_Prune = true;

  /** Total number of instances used to build the classifier. */
  protected int m_totalTrainInstances;

  /** Proportion for each branch. */
  protected double[] m_Props;

  /** Class probabilities. */
  protected double[] m_ClassProbs = null;

  /**
   * Distributions of leaf node (or temporary leaf node in minimal
   * cost-complexity pruning)
   */
  protected double[] m_Distribution;

  /**
   * If use heuristic search for nominal attributes in multi-class problems
   * (default true).
   */
  protected boolean m_Heuristic = true;

  /** If use the 1SE rule to make final decision tree. */
  protected boolean m_UseOneSE = false;

  /** Proportion of the training data used for pruning, in (0-1]. */
  protected double m_SizePer = 1;

  /**
   * Returns a string describing the classifier, suitable for displaying in
   * the explorer/experimenter GUI.
   *
   * @return a description suitable for displaying in the explorer/experimenter
   */
  public String globalInfo() {
    StringBuilder description = new StringBuilder();
    description.append("Class implementing minimal cost-complexity pruning.\n");
    description.append("Note when dealing with missing values, use \"fractional ");
    description.append("instances\" method instead of surrogate split method.\n\n");
    description.append("For more information, see:\n\n");
    description.append(getTechnicalInformation().toString());
    return description.toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing detailed
   * information about the technical background of this class, e.g., paper
   * reference or book this class is based on.
   * 
   * @return the technical information about this class
   */
  @Override
  public TechnicalInformation getTechnicalInformation() {

    // the CART monograph this implementation is based on
    TechnicalInformation info = new TechnicalInformation(Type.BOOK);
    info.setValue(Field.AUTHOR,
      "Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone");
    info.setValue(Field.YEAR, "1984");
    info.setValue(Field.TITLE, "Classification and Regression Trees");
    info.setValue(Field.PUBLISHER, "Wadsworth International Group");
    info.setValue(Field.ADDRESS, "Belmont, California");
    return info;
  }

  /**
   * Returns default capabilities of the classifier.
   * 
   * @return the capabilities of this classifier
   */
  @Override
  public Capabilities getCapabilities() {

    // start from a clean slate, then declare exactly what CART can handle
    Capabilities caps = super.getCapabilities();
    caps.disableAll();

    // supported attribute types (missing values handled via fractional instances)
    caps.enable(Capability.NOMINAL_ATTRIBUTES);
    caps.enable(Capability.NUMERIC_ATTRIBUTES);
    caps.enable(Capability.MISSING_VALUES);

    // classification only (nominal class)
    caps.enable(Capability.NOMINAL_CLASS);

    return caps;
  }

  /**
   * Build the classifier. If pruning is disabled, a single unpruned tree is
   * grown on all the data. Otherwise, minimal cost-complexity pruning is
   * performed: alpha/error sequences are collected over m_numFoldsPruning
   * cross-validation folds, the best alpha is selected (optionally using the
   * 1-SE rule), and the tree grown on all the data is pruned back to it.
   * 
   * @param data the training instances
   * @throws Exception if something goes wrong
   */
  @Override
  public void buildClassifier(Instances data) throws Exception {

    // check data against capabilities, then work on a copy with
    // missing-class instances removed
    getCapabilities().testWithFail(data);
    data = new Instances(data);
    data.deleteWithMissingClass();

    // unpruned CART decision tree
    if (!m_Prune) {

      // calculate sorted indices and weights, and compute initial class counts.
      int[][] sortedIndices = new int[data.numAttributes()][0];
      double[][] weights = new double[data.numAttributes()][0];
      double[] classProbs = new double[data.numClasses()];
      double totalWeight = computeSortedInfo(data, sortedIndices, weights,
        classProbs);

      makeTree(data, data.numInstances(), sortedIndices, weights, classProbs,
        totalWeight, m_minNumObj, m_Heuristic);
      return;
    }

    // randomize, subsample (per m_SizePer) and stratify the data used for
    // the pruning cross-validation.
    // NOTE(review): with the default m_SizePer == 1 this keeps
    // numInstances()-1 instances, i.e. one instance is always dropped from
    // the pruning data — looks like an off-by-one; confirm it is intended.
    Random random = new Random(m_Seed);
    Instances cvData = new Instances(data);
    cvData.randomize(random);
    cvData = new Instances(cvData, 0,
      (int) (cvData.numInstances() * m_SizePer) - 1);
    cvData.stratify(m_numFoldsPruning);

    double[][] alphas = new double[m_numFoldsPruning][];
    double[][] errors = new double[m_numFoldsPruning][];

    // calculate errors and alphas for each fold
    for (int i = 0; i < m_numFoldsPruning; i++) {

      // for every fold, grow tree on training set and fix error on test set.
      Instances train = cvData.trainCV(m_numFoldsPruning, i);
      Instances test = cvData.testCV(m_numFoldsPruning, i);

      // calculate sorted indices and weights, and compute initial class counts
      // for each fold
      int[][] sortedIndices = new int[train.numAttributes()][0];
      double[][] weights = new double[train.numAttributes()][0];
      double[] classProbs = new double[train.numClasses()];
      double totalWeight = computeSortedInfo(train, sortedIndices, weights,
        classProbs);

      makeTree(train, train.numInstances(), sortedIndices, weights, classProbs,
        totalWeight, m_minNumObj, m_Heuristic);

      // +2 slots: one for the unpruned tree (alpha = 0) and one terminator
      int numNodes = numInnerNodes();
      alphas[i] = new double[numNodes + 2];
      errors[i] = new double[numNodes + 2];

      // prune back and log alpha-values and errors on test set
      prune(alphas[i], errors[i], test);
    }

    // calculate sorted indices and weights, and compute initial class counts on
    // all training instances
    int[][] sortedIndices = new int[data.numAttributes()][0];
    double[][] weights = new double[data.numAttributes()][0];
    double[] classProbs = new double[data.numClasses()];
    double totalWeight = computeSortedInfo(data, sortedIndices, weights,
      classProbs);

    // build tree using all the data
    makeTree(data, data.numInstances(), sortedIndices, weights, classProbs,
      totalWeight, m_minNumObj, m_Heuristic);

    int numNodes = numInnerNodes();

    double[] treeAlphas = new double[numNodes + 2];

    // prune back and log alpha-values
    int iterations = prune(treeAlphas, null, null);

    double[] treeErrors = new double[numNodes + 2];

    // for each pruned subtree, find the cross-validated error
    for (int i = 0; i <= iterations; i++) {
      // compute midpoint alphas (geometric mean of consecutive alpha-values)
      double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i + 1]);
      double error = 0;
      for (int k = 0; k < m_numFoldsPruning; k++) {
        // locate the subtree of fold k whose alpha interval contains alpha
        int l = 0;
        while (alphas[k][l] <= alpha) {
          l++;
        }
        error += errors[k][l - 1];
      }
      treeErrors[i] = error / m_numFoldsPruning;
    }

    // find best alpha; iterating downward with a strict "<" means ties
    // favour the more heavily pruned (smaller) tree
    int best = -1;
    double bestError = Double.MAX_VALUE;
    for (int i = iterations; i >= 0; i--) {
      if (treeErrors[i] < bestError) {
        bestError = treeErrors[i];
        best = i;
      }
    }

    // 1 SE rule: choose the most heavily pruned tree whose CV error is
    // within one standard error of the best error
    if (m_UseOneSE) {
      double oneSE = Math.sqrt(bestError * (1 - bestError)
        / (data.numInstances()));
      for (int i = iterations; i >= 0; i--) {
        if (treeErrors[i] <= bestError + oneSE) {
          best = i;
          break;
        }
      }
    }

    double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);

    // "unprune" final tree (faster than regrowing it)
    unprune();
    prune(bestAlpha);
  }

  /**
   * Make binary decision tree recursively: pick the attribute/split with the
   * maximum Gini gain, split the data (fractionally distributing instances
   * with missing values), and recurse into the two successor nodes. A node
   * becomes a leaf when it is too small, pure, or cannot be split.
   * 
   * @param data the training instances
   * @param totalInstances total number of instances
   * @param sortedIndices sorted indices of the instances
   * @param weights weights of the instances
   * @param classProbs class probabilities
   * @param totalWeight total weight of instances
   * @param minNumObj minimal number of instances at leaf nodes
   * @param useHeuristic if use heuristic search for nominal attributes in
   *          multi-class problem
   * @throws Exception if something goes wrong
   */
  protected void makeTree(Instances data, int totalInstances,
    int[][] sortedIndices, double[][] weights, double[] classProbs,
    double totalWeight, double minNumObj, boolean useHeuristic)
    throws Exception {

    // if no instances have reached this node (normally won't happen)
    if (totalWeight == 0) {
      m_Attribute = null;
      m_ClassValue = Utils.missingValue();
      m_Distribution = new double[data.numClasses()];
      return;
    }

    m_totalTrainInstances = totalInstances;
    m_isLeaf = true;
    m_Successors = null;

    // copy class counts: m_Distribution keeps the raw (weighted) counts,
    // m_ClassProbs is normalized to probabilities when the sum is non-zero
    m_ClassProbs = new double[classProbs.length];
    m_Distribution = new double[classProbs.length];
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
    System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
    if (Utils.sum(m_ClassProbs) != 0) {
      Utils.normalize(m_ClassProbs);
    }

    // Compute class distributions and value of splitting
    // criterion for each attribute
    double[][][] dists = new double[data.numAttributes()][0][0];
    double[][] props = new double[data.numAttributes()][0];
    double[][] totalSubsetWeights = new double[data.numAttributes()][2];
    double[] splits = new double[data.numAttributes()];
    String[] splitString = new String[data.numAttributes()];
    double[] giniGains = new double[data.numAttributes()];

    // for each attribute find split information
    for (int i = 0; i < data.numAttributes(); i++) {
      Attribute att = data.attribute(i);
      if (i == data.classIndex()) {
        continue;
      }
      if (att.isNumeric()) {
        // numeric attribute: best split point
        splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
          weights[i], totalSubsetWeights, giniGains, data);
      } else {
        // nominal attribute: best subset of values
        splitString[i] = nominalDistribution(props, dists, att,
          sortedIndices[i], weights[i], totalSubsetWeights, giniGains, data,
          useHeuristic);
      }
    }

    // Find best attribute (split with maximum Gini gain)
    int attIndex = Utils.maxIndex(giniGains);
    m_Attribute = data.attribute(attIndex);

    // keep the (weighted) training instances at this node; used when
    // pruning later turns this node into a leaf
    m_train = new Instances(data, sortedIndices[attIndex].length);
    for (int i = 0; i < sortedIndices[attIndex].length; i++) {
      Instance inst = data.instance(sortedIndices[attIndex][i]);
      Instance instCopy = (Instance) inst.copy();
      instCopy.setWeight(weights[attIndex][i]);
      m_train.add(instCopy);
    }

    // Check if node does not contain enough instances, or if it can not be
    // split (zero Gini gain or an empty branch), or if it is pure.
    // If so, make leaf.
    if (totalWeight < 2 * minNumObj || giniGains[attIndex] == 0
      || props[attIndex][0] == 0 || props[attIndex][1] == 0) {
      makeLeaf(data);
    }

    else {
      m_Props = props[attIndex];
      int[][][] subsetIndices = new int[2][data.numAttributes()][0];
      double[][][] subsetWeights = new double[2][data.numAttributes()][0];

      // record the chosen split: a threshold for numeric attributes, a
      // value-subset string for nominal ones
      if (m_Attribute.isNumeric()) {
        m_SplitValue = splits[attIndex];
      } else {
        m_SplitString = splitString[attIndex];
      }

      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue,
        m_SplitString, sortedIndices, weights, data);

      // If split of the node results in a node with less than minimal number of
      // instances, make the node leaf node.
      if (subsetIndices[0][attIndex].length < minNumObj
        || subsetIndices[1][attIndex].length < minNumObj) {
        makeLeaf(data);
        return;
      }

      // Otherwise, split the node.
      m_isLeaf = false;
      m_Successors = new SimpleCart[2];
      for (int i = 0; i < 2; i++) {
        m_Successors[i] = new SimpleCart();
        m_Successors[i].makeTree(data, m_totalTrainInstances, subsetIndices[i],
          subsetWeights[i], dists[attIndex][i],
          totalSubsetWeights[attIndex][i], minNumObj, useHeuristic);
      }
    }
  }

  /**
   * Prunes the original tree using the CART pruning scheme, given a
   * cost-complexity parameter alpha. Repeatedly collapses the "weakest link"
   * (inner node with minimum alpha) into a leaf until every remaining inner
   * node has an alpha-value above the given threshold.
   * 
   * @param alpha the cost-complexity parameter
   * @throws Exception if something goes wrong
   */
  public void prune(double alpha) throws Exception {

    Vector<SimpleCart> nodeList;

    // determine training error of pruned subtrees (both with and without
    // replacing a subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      // select node with minimum alpha (the weakest link)
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // stop once the weakest link's alpha exceeds the requested alpha
      if (nodeToPrune.m_Alpha > alpha) {
        break;
      }

      nodeToPrune.makeLeaf(nodeToPrune.m_train);

      // normally would not happen
      // NOTE(review): the node was already turned into a leaf just above, so
      // this second makeLeaf call looks redundant — confirm makeLeaf is
      // idempotent for this case
      if (nodeToPrune.m_Alpha == preAlpha) {
        nodeToPrune.makeLeaf(nodeToPrune.m_train);
        treeErrors();
        calculateAlphas();
        nodeList = getInnerNodes();
        prune = (nodeList.size() > 0);
        continue;
      }
      preAlpha = nodeToPrune.m_Alpha;

      // update tree errors and alphas
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }
  }

  /**
   * Method for performing one fold in the cross-validation of minimal
   * cost-complexity pruning. Generates a sequence of alpha-values with error
   * estimates for the corresponding (partially pruned) trees, given the test
   * set of that fold. Nodes are only marked as leaves (m_isLeaf), not
   * detached, so the tree can later be restored with unprune().
   * 
   * @param alphas array to hold the generated alpha-values
   * @param errors array to hold the corresponding error estimates
   * @param test test set of that fold (to obtain error estimates)
   * @return the iteration of the pruning
   * @throws Exception if something goes wrong
   */
  public int prune(double[] alphas, double[] errors, Instances test)
    throws Exception {

    Vector<SimpleCart> nodeList;

    // determine training error of subtrees (both with and without replacing a
    // subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);

    // alpha_0 is always zero (unpruned tree)
    alphas[0] = 0;

    Evaluation eval;

    // error of unpruned tree (errors may be null when only alphas are needed)
    if (errors != null) {
      eval = new Evaluation(test);
      eval.evaluateModel(this, test);
      errors[0] = eval.errorRate();
    }

    int iteration = 0;
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      iteration++;

      // get node with minimum alpha (the weakest link)
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // do not set m_sons null, want to unprune
      nodeToPrune.m_isLeaf = true;

      // duplicate alpha: do not record a new sequence entry
      // (normally would not happen)
      if (nodeToPrune.m_Alpha == preAlpha) {
        iteration--;
        treeErrors();
        calculateAlphas();
        nodeList = getInnerNodes();
        prune = (nodeList.size() > 0);
        continue;
      }

      // get alpha-value of node
      alphas[iteration] = nodeToPrune.m_Alpha;

      // log error of the partially pruned tree on the fold's test set
      if (errors != null) {
        eval = new Evaluation(test);
        eval.evaluateModel(this, test);
        errors[iteration] = eval.errorRate();
      }
      preAlpha = nodeToPrune.m_Alpha;

      // update errors/alphas
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }

    // set last alpha 1 to indicate end
    alphas[iteration + 1] = 1.0;
    return iteration;
  }

  /**
   * Method to "unprune" the CART tree. Sets all leaf-fields to false. Faster
   * than re-growing the tree because CART do not have to be fit again.
   */
  protected void unprune() {
    // a node without successors is a true leaf; nothing to restore
    if (m_Successors == null) {
      return;
    }
    m_isLeaf = false;
    for (int i = 0; i < m_Successors.length; i++) {
      m_Successors[i].unprune();
    }
  }

  /**
   * Compute distributions, proportions and total weights of two successor nodes
   * for a given numeric attribute. Instances with missing values are
   * distributed fractionally across both branches in proportion to the
   * branch weights of the non-missing instances.
   * 
   * @param props proportions of each two branches for each attribute
   * @param dists class distributions of two branches for each attribute
   * @param att numeric att split on
   * @param sortedIndices sorted indices of instances for the attribute
   * @param weights weights of instances for the attribute
   * @param subsetWeights total weight of two branches split based on the
   *          attribute
   * @param giniGains Gini gains for each attribute
   * @param data training instances
   * @return Gini gain the given numeric attribute (may stay NaN if no valid
   *         split point is found — presumably when all values are equal or
   *         missing; callers should be aware — TODO confirm)
   * @throws Exception if something goes wrong
   */
  protected double numericDistribution(double[][] props, double[][][] dists,
    Attribute att, int[] sortedIndices, double[] weights,
    double[][] subsetWeights, double[] giniGains, Instances data)
    throws Exception {

    double splitPoint = Double.NaN;
    double[][] dist = null;
    int numClasses = data.numClasses();
    int i; // differ instances with or without missing values

    double[][] currDist = new double[2][numClasses];
    dist = new double[2][numClasses];

    // Move all instances without missing values into second subset.
    // NOTE(review): missingStart is used below as the index where the
    // missing-value instances begin — assumes sortedIndices places instances
    // with missing values last (presumably guaranteed by computeSortedInfo;
    // verify).
    double[] parentDist = new double[numClasses];
    int missingStart = 0;
    for (int j = 0; j < sortedIndices.length; j++) {
      Instance inst = data.instance(sortedIndices[j]);
      if (!inst.isMissing(att)) {
        missingStart++;
        currDist[1][(int) inst.classValue()] += weights[j];
      }
      parentDist[(int) inst.classValue()] += weights[j];
    }
    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

    // Try all possible split points (midpoints between distinct adjacent
    // attribute values in sorted order)
    double currSplit = data.instance(sortedIndices[0]).value(att);
    double currGiniGain;
    double bestGiniGain = -Double.MAX_VALUE;

    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }
      if (inst.value(att) > currSplit) {

        // snapshot of the current left/right class distributions
        double[][] tempDist = new double[2][numClasses];
        for (int k = 0; k < 2; k++) {
          // tempDist[k] = currDist[k];
          System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
        }

        // branch proportions used to distribute missing-value instances
        double[] tempProps = new double[2];
        for (int k = 0; k < 2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps) != 0) {
          Utils.normalize(tempProps);
        }

        // split missing values fractionally by branch proportions
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int) insta.classValue()] += tempProps[j]
              * weights[index];
          }
          index++;
        }

        currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;

          // clean split point
          // splitPoint = Math.rint((inst.value(att) +
          // currSplit)/2.0*100000)/100000.0;
          splitPoint = (inst.value(att) + currSplit) / 2.0;

          for (int j = 0; j < currDist.length; j++) {
            System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length);
          }
        }
      }
      // move the current instance from the right branch to the left one
      currSplit = inst.value(att);
      currDist[0][(int) inst.classValue()] += weights[i];
      currDist[1][(int) inst.classValue()] -= weights[i];
    }

    // Compute weights (normalized branch proportions for the best split)
    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }
    if (Utils.sum(props[attIndex]) != 0) {
      Utils.normalize(props[attIndex]);
    }

    // Compute subset weights (unnormalized total weight per branch)
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // clean Gini gain
    // giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
    giniGains[attIndex] = bestGiniGain;
    dists[attIndex] = dist;

    return splitPoint;
  }

  /**
   * Compute distributions, proportions and total weights of two successor nodes
   * for a given nominal attribute.
   * 
   * @param props proportions of each two branches for each attribute
   * @param dists class distributions of two branches for each attribute
 * @param att nominal attribute split on
 * @param sortedIndices sorted indices of instances for the attribute
 * @param weights weights of instances for the attribute
   * @param subsetWeights total weight of two branches split based on the
   *          attribute
   * @param giniGains Gini gains for each attribute
   * @param data training instances
   * @param useHeuristic if use heuristic search
   * @return Gini gain for the given nominal attribute
   * @throws Exception if something goes wrong
   */
  protected String nominalDistribution(double[][] props, double[][][] dists,
    Attribute att, int[] sortedIndices, double[] weights,
    double[][] subsetWeights, double[] giniGains, Instances data,
    boolean useHeuristic) throws Exception {

    String[] values = new String[att.numValues()];
    int numCat = values.length; // number of values of the attribute
    int numClasses = data.numClasses();

    String bestSplitString = "";
    double bestGiniGain = -Double.MAX_VALUE;

    // class frequency for each value
    int[] classFreq = new int[numCat];
    for (int j = 0; j < numCat; j++) {
      classFreq[j] = 0;
    }

    double[] parentDist = new double[numClasses];
    double[][] currDist = new double[2][numClasses];
    double[][] dist = new double[2][numClasses];
    int missingStart = 0;

    for (int i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (!inst.isMissing(att)) {
        missingStart++;
        classFreq[(int) inst.value(att)]++;
      }
      parentDist[(int) inst.classValue()] += weights[i];
    }

    // count the number of values that class frequency is not 0
    int nonEmpty = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] != 0) {
        nonEmpty++;
      }
    }

    // attribute values that class frequency is not 0
    String[] nonEmptyValues = new String[nonEmpty];
    int nonEmptyIndex = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] != 0) {
        nonEmptyValues[nonEmptyIndex] = att.value(j);
        nonEmptyIndex++;
      }
    }

    // attribute values that class frequency is 0
    int empty = numCat - nonEmpty;
    String[] emptyValues = new String[empty];
    int emptyIndex = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] == 0) {
        emptyValues[emptyIndex] = att.value(j);
        emptyIndex++;
      }
    }

    if (nonEmpty <= 1) {
      giniGains[att.index()] = 0;
      return "";
    }

    // for two-class problems
    if (data.numClasses() == 2) {

      // // Firstly, for attribute values which class frequency is not zero

      // probability of class 0 for each attribute value
      double[] pClass0 = new double[nonEmpty];
      // class distribution for each attribute value
      double[][] valDist = new double[nonEmpty][2];

      for (int j = 0; j < nonEmpty; j++) {
        for (int k = 0; k < 2; k++) {
          valDist[j][k] = 0;
        }
      }

      for (int sortedIndice : sortedIndices) {
        Instance inst = data.instance(sortedIndice);
        if (inst.isMissing(att)) {
          break;
        }

        for (int j = 0; j < nonEmpty; j++) {
          if (att.value((int) inst.value(att)).compareTo(nonEmptyValues[j]) == 0) {
            valDist[j][(int) inst.classValue()] += inst.weight();
            break;
          }
        }
      }

      for (int j = 0; j < nonEmpty; j++) {
        double distSum = Utils.sum(valDist[j]);
        if (distSum == 0) {
          pClass0[j] = 0;
        } else {
          pClass0[j] = valDist[j][0] / distSum;
        }
      }

      // sort category according to the probability of the first class
      String[] sortedValues = new String[nonEmpty];
      for (int j = 0; j < nonEmpty; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
        pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
      }

      // Find a subset of attribute values that maximize Gini decrease

      // for the attribute values that class frequency is not 0
      String tempStr = "";

      for (int j = 0; j < nonEmpty - 1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr == "") {
          tempStr = "(" + sortedValues[j] + ")";
        } else {
          tempStr += "|" + "(" + sortedValues[j] + ")";
        }
        for (int i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf("(" + att.value((int) inst.value(att)) + ")") != -1) {
            currDist[0][(int) inst.classValue()] += weights[i];
          } else {
            currDist[1][(int) inst.classValue()] += weights[i];
          }
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk = 0; kk < 2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk = 0; kk < 2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps) != 0) {
          Utils.normalize(tempProps);
        }

        // split missing values
        int mstart = missingStart;
        while (mstart < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[mstart]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int) insta.classValue()] += tempProps[jj]
              * weights[mstart];
          }
          mstart++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int jj = 0; jj < 2; jj++) {
            // dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[jj], 0, dist[jj], 0, dist[jj].length);
          }
        }
      }
    }

    // multi-class problems - exhaustive search
    else if (!useHeuristic || nonEmpty <= 4) {

      // Firstly, for attribute values which class frequency is not zero
      for (int i = 0; i < (int) Math.pow(2, nonEmpty - 1); i++) {
        String tempStr = "";
        currDist = new double[2][numClasses];
        int mod;
        int bit10 = i;
        for (int j = nonEmpty - 1; j >= 0; j--) {
          mod = bit10 % 2; // convert from 10bit to 2bit
          if (mod == 1) {
            if (tempStr == "") {
              tempStr = "(" + nonEmptyValues[j] + ")";
            } else {
              tempStr += "|" + "(" + nonEmptyValues[j] + ")";
            }
          }
          bit10 = bit10 / 2;
        }
        for (int j = 0; j < sortedIndices.length; j++) {
          Instance inst = data.instance(sortedIndices[j]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf("(" + att.value((int) inst.value(att)) + ")") != -1) {
            currDist[0][(int) inst.classValue()] += weights[j];
          } else {
            currDist[1][(int) inst.classValue()] += weights[j];
          }
        }

        double[][] tempDist = new double[2][numClasses];
        for (int k = 0; k < 2; k++) {
          tempDist[k] = currDist[k];
        }

        double[] tempProps = new double[2];
        for (int k = 0; k < 2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps) != 0) {
          Utils.normalize(tempProps);
        }

        // split missing values
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int) insta.classValue()] += tempProps[j]
              * weights[index];
          }
          index++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int j = 0; j < 2; j++) {
            // dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length);
          }
        }
      }
    }

    // huristic search to solve multi-classes problems
    else {
      // Firstly, for attribute values which class frequency is not zero
      int n = nonEmpty;
      int k = data.numClasses(); // number of classes of the data
      double[][] P = new double[n][k]; // class probability matrix
      int[] numInstancesValue = new int[n]; // number of instances for an
                                            // attribute value
      double[] meanClass = new double[k]; // vector of mean class probability
      int numInstances = data.numInstances(); // total number of instances

      // initialize the vector of mean class probability
      for (int j = 0; j < meanClass.length; j++) {
        meanClass[j] = 0;
      }

      for (int j = 0; j < numInstances; j++) {
        Instance inst = data.instance(j);
        int valueIndex = 0; // attribute value index in nonEmptyValues
        for (int i = 0; i < nonEmpty; i++) {
          if (att.value((int) inst.value(att)).compareToIgnoreCase(
            nonEmptyValues[i]) == 0) {
            valueIndex = i;
            break;
          }
        }
        P[valueIndex][(int) inst.classValue()]++;
        numInstancesValue[valueIndex]++;
        meanClass[(int) inst.classValue()]++;
      }

      // calculate the class probability matrix
      for (int i = 0; i < P.length; i++) {
        for (int j = 0; j < P[0].length; j++) {
          if (numInstancesValue[i] == 0) {
            P[i][j] = 0;
          } else {
            P[i][j] /= numInstancesValue[i];
          }
        }
      }

      // calculate the vector of mean class probability
      for (int i = 0; i < meanClass.length; i++) {
        meanClass[i] /= numInstances;
      }

      // calculate the covariance matrix
      double[][] covariance = new double[k][k];
      for (int i1 = 0; i1 < k; i1++) {
        for (int i2 = 0; i2 < k; i2++) {
          double element = 0;
          for (int j = 0; j < n; j++) {
            element += (P[j][i2] - meanClass[i2]) * (P[j][i1] - meanClass[i1])
              * numInstancesValue[j];
          }
          covariance[i1][i2] = element;
        }
      }

      Matrix matrix = new Matrix(covariance);
      weka.core.matrix.EigenvalueDecomposition eigen = new weka.core.matrix.EigenvalueDecomposition(
        matrix);
      double[] eigenValues = eigen.getRealEigenvalues();

      // find index of the largest eigenvalue
      int index = 0;
      double largest = eigenValues[0];
      for (int i = 1; i < eigenValues.length; i++) {
        if (eigenValues[i] > largest) {
          index = i;
          largest = eigenValues[i];
        }
      }

      // calculate the first principle component
      double[] FPC = new double[k];
      Matrix eigenVector = eigen.getV();
      double[][] vectorArray = eigenVector.getArray();
      for (int i = 0; i < FPC.length; i++) {
        FPC[i] = vectorArray[i][index];
      }

      // calculate the first principle component scores
      // System.out.println("the first principle component scores: ");
      double[] Sa = new double[n];
      for (int i = 0; i < Sa.length; i++) {
        Sa[i] = 0;
        for (int j = 0; j < k; j++) {
          Sa[i] += FPC[j] * P[i][j];
        }
      }

      // sort category according to Sa(s)
      double[] pCopy = new double[n];
      System.arraycopy(Sa, 0, pCopy, 0, n);
      String[] sortedValues = new String[n];
      Arrays.sort(Sa);

      for (int j = 0; j < n; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
        pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
      }

      // for the attribute values that class frequency is not 0
      String tempStr = "";

      for (int j = 0; j < nonEmpty - 1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr == "") {
          tempStr = "(" + sortedValues[j] + ")";
        } else {
          tempStr += "|" + "(" + sortedValues[j] + ")";
        }
        for (int i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf("(" + att.value((int) inst.value(att)) + ")") != -1) {
            currDist[0][(int) inst.classValue()] += weights[i];
          } else {
            currDist[1][(int) inst.classValue()] += weights[i];
          }
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk = 0; kk < 2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk = 0; kk < 2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps) != 0) {
          Utils.normalize(tempProps);
        }

        // split missing values
        int mstart = missingStart;
        while (mstart < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[mstart]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int) insta.classValue()] += tempProps[jj]
              * weights[mstart];
          }
          mstart++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int jj = 0; jj < 2; jj++) {
            // dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[jj], 0, dist[jj], 0, dist[jj].length);
          }
        }
      }
    }

    // Compute weights
    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }

    if (!(Utils.sum(props[attIndex]) > 0)) {
      for (int k = 0; k < props[attIndex].length; k++) {
        props[attIndex][k] = 1.0 / props[attIndex].length;
      }
    } else {
      Utils.normalize(props[attIndex]);
    }

    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // Then, for the attribute values that class frequency is 0, split it into
    // the
    // most frequent branch
    for (int j = 0; j < empty; j++) {
      if (props[attIndex][0] >= props[attIndex][1]) {
        if (bestSplitString == "") {
          bestSplitString = "(" + emptyValues[j] + ")";
        } else {
          bestSplitString += "|" + "(" + emptyValues[j] + ")";
        }
      }
    }

    // clean Gini gain for the attribute
    // giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
    giniGains[attIndex] = bestGiniGain;

    dists[attIndex] = dist;
    return bestSplitString;
  }

  /**
   * Split data into two subsets and store sorted indices and weights for two
   * successor nodes.
   * 
   * Instances with a missing value for the split attribute are sent down both
   * branches with their weight scaled by the branch proportions (m_Props),
   * i.e. the "fractional instances" method.
   * 
   * @param subsetIndices filled with sorted indices of instances for each
   *          attribute for the two successor nodes
   * @param subsetWeights filled with weights of instances for each attribute
   *          for the two successor nodes
   * @param att attribute the split is based on
   * @param splitPoint split point the split is based on if att is numeric
   * @param splitStr split subset the split is based on if att is nominal
   * @param sortedIndices sorted indices of the instances to be split
   * @param weights weights of the instances to be split
   * @param data training data
   * @throws Exception if something goes wrong
   */
  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
    Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,
    double[][] weights, Instances data) throws Exception {

    // For each attribute
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i == data.classIndex()) {
        continue;
      }
      int[] num = new int[2];
      for (int k = 0; k < 2; k++) {
        subsetIndices[k][i] = new int[sortedIndices[i].length];
        subsetWeights[k][i] = new double[weights[i].length];
      }

      // loop variable scoped to the loop (was declared at method level)
      for (int j = 0; j < sortedIndices[i].length; j++) {
        Instance inst = data.instance(sortedIndices[i][j]);
        if (inst.isMissing(att)) {
          // Split instance up: send it down both branches with
          // proportionally reduced weight
          for (int k = 0; k < 2; k++) {
            if (m_Props[k] > 0) {
              subsetIndices[k][i][num[k]] = sortedIndices[i][j];
              subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
              num[k]++;
            }
          }
        } else {
          int subset;
          if (att.isNumeric()) {
            subset = (inst.value(att) < splitPoint) ? 0 : 1;
          } else { // nominal attribute: left branch iff value appears in splitStr
            if (splitStr.indexOf("(" + att.value((int) inst.value(att.index()))
              + ")") != -1) {
              subset = 0;
            } else {
              subset = 1;
            }
          }
          subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
          subsetWeights[subset][i][num[subset]] = weights[i][j];
          num[subset]++;
        }
      }

      // Trim arrays to the number of entries actually written
      for (int k = 0; k < 2; k++) {
        subsetIndices[k][i] = Arrays.copyOf(subsetIndices[k][i], num[k]);
        subsetWeights[k][i] = Arrays.copyOf(subsetWeights[k][i], num[k]);
      }
    }
  }

  /**
   * Updates the numIncorrectModel field for all nodes when subtree (to be
   * pruned) is rooted. This is needed for calculating the alpha-values.
   * 
   * Each node is evaluated on the training data as if it were a leaf (its
   * subtree, if any, is temporarily disabled) and the resulting number of
   * misclassified instances is stored in m_numIncorrectModel.
   * 
   * @throws Exception if something goes wrong
   */
  public void modelErrors() throws Exception {
    Evaluation eval = new Evaluation(m_train);

    // Evaluate this node as a leaf; save and restore the leaf flag so
    // inner nodes keep their subtree afterwards. (For leaves the flag is
    // already true, so this is a no-op toggle.)
    boolean wasLeaf = m_isLeaf;
    m_isLeaf = true;
    eval.evaluateModel(this, m_train);
    m_numIncorrectModel = eval.incorrect();
    m_isLeaf = wasLeaf;

    // Recurse into the subtree of inner nodes
    if (!m_isLeaf) {
      for (SimpleCart successor : m_Successors) {
        successor.modelErrors();
      }
    }
  }

  /**
   * Updates the numIncorrectTree field for all nodes. This is needed for
   * calculating the alpha-values.
   * 
   * A leaf's subtree error equals its own model error; an inner node's is
   * the sum of its successors' subtree errors.
   * 
   * @throws Exception if something goes wrong
   */
  public void treeErrors() throws Exception {
    if (m_isLeaf) {
      m_numIncorrectTree = m_numIncorrectModel;
      return;
    }
    double errorSum = 0;
    for (SimpleCart successor : m_Successors) {
      successor.treeErrors();
      errorSum += successor.m_numIncorrectTree;
    }
    m_numIncorrectTree = errorSum;
  }

  /**
   * Updates the alpha field for all nodes.
   * 
   * Alpha is the cost-complexity measure: the per-leaf reduction in training
   * error contributed by the subtree. Leaves (and subtrees that do not reduce
   * the error, which are pruned on the spot) get Double.MAX_VALUE so they are
   * never chosen for pruning.
   * 
   * @throws Exception if something goes wrong
   */
  public void calculateAlphas() throws Exception {

    if (m_isLeaf) {
      // alpha = infinite for leaves (do not want to prune)
      m_Alpha = Double.MAX_VALUE;
      return;
    }

    double errorDiff = m_numIncorrectModel - m_numIncorrectTree;
    if (errorDiff <= 0) {
      // split increases training error (should not normally happen).
      // prune it instantly.
      makeLeaf(m_train);
      m_Alpha = Double.MAX_VALUE;
      return;
    }

    // compute alpha, rounded to 10 decimal places
    errorDiff /= m_totalTrainInstances;
    m_Alpha = errorDiff / (numLeaves() - 1);
    long alphaLong = Math.round(m_Alpha * Math.pow(10, 10));
    m_Alpha = alphaLong / Math.pow(10, 10);
    for (SimpleCart successor : m_Successors) {
      successor.calculateAlphas();
    }
  }

  /**
   * Find the node with minimal alpha value. If two nodes have the same alpha,
   * choose the one with more leaf nodes.
   * 
   * @param nodeList list of inner nodes
   * @return the node to be pruned, or null if the list is empty
   */
  protected SimpleCart nodeToPrune(Vector<SimpleCart> nodeList) {
    if (nodeList.isEmpty()) {
      return null;
    }
    SimpleCart best = nodeList.elementAt(0);
    for (int i = 1; i < nodeList.size(); i++) {
      SimpleCart candidate = nodeList.elementAt(i);
      if ((candidate.m_Alpha < best.m_Alpha)
        || ((candidate.m_Alpha == best.m_Alpha) // break tie on leaf count
          && (candidate.numLeaves() > best.numLeaves()))) {
        best = candidate;
      }
    }
    return best;
  }

  /**
   * Compute sorted indices, weights and class probabilities for a given
   * dataset. Return total weights of the data at the node.
   * 
   * For nominal attributes instances keep their original order, with
   * missing-value instances moved to the end. For numeric attributes
   * instances are ordered by attribute value via Utils.sort.
   * 
   * @param data training data
   * @param sortedIndices sorted indices of instances at the node (filled in)
   * @param weights weights of instances at the node (filled in)
   * @param classProbs class probabilities at the node (filled in)
   * @return total weight of the instances at the node
   * @throws Exception if something goes wrong
   */
  protected double computeSortedInfo(Instances data, int[][] sortedIndices,
    double[][] weights, double[] classProbs) throws Exception {

    int numInst = data.numInstances();
    double[] attVals = new double[numInst];

    for (int att = 0; att < data.numAttributes(); att++) {
      if (att == data.classIndex()) {
        continue;
      }
      weights[att] = new double[numInst];

      if (data.attribute(att).isNominal()) {

        // Nominal attribute: keep original instance order, but push
        // instances with a missing value for this attribute to the end.
        // Pass 0 collects non-missing instances, pass 1 the missing ones.
        sortedIndices[att] = new int[numInst];
        int pos = 0;
        for (int pass = 0; pass < 2; pass++) {
          boolean wantMissing = (pass == 1);
          for (int i = 0; i < numInst; i++) {
            Instance inst = data.instance(i);
            if (inst.isMissing(att) == wantMissing) {
              sortedIndices[att][pos] = i;
              weights[att][pos] = inst.weight();
              pos++;
            }
          }
        }
      } else {

        // Numeric attribute: sort by value; missing-value instances
        // are placed at the end by Utils.sort.
        for (int i = 0; i < numInst; i++) {
          attVals[i] = data.instance(i).value(att);
        }
        sortedIndices[att] = Utils.sort(attVals);
        for (int i = 0; i < numInst; i++) {
          weights[att][i] = data.instance(sortedIndices[att][i]).weight();
        }
      }
    }

    // Accumulate per-class weights and the total weight at this node
    double totalWeight = 0;
    for (int i = 0; i < numInst; i++) {
      Instance inst = data.instance(i);
      classProbs[(int) inst.classValue()] += inst.weight();
      totalWeight += inst.weight();
    }

    return totalWeight;
  }

  /**
   * Compute and return gini gain for given distributions of a node and its
   * successor nodes.
   * 
   * @param parentDist class distribution of the parent node
   * @param childDist class distributions of the two successor nodes
   * @return the Gini gain of the split (0 if the parent has no weight)
   */
  protected double computeGiniGain(double[] parentDist, double[][] childDist) {
    double totalWeight = Utils.sum(parentDist);
    if (totalWeight == 0) {
      return 0;
    }

    // parent impurity minus the weighted impurity of the two branches
    double gain = computeGini(parentDist, totalWeight);
    for (double[] branch : childDist) {
      double branchWeight = Utils.sum(branch);
      gain -= branchWeight / totalWeight * computeGini(branch, branchWeight);
    }
    return gain;
  }

  /**
   * Compute and return gini index for a given class distribution.
   * 
   * @param dist class distribution
   * @param total total weight of the distribution
   * @return the Gini index (0 if total weight is 0)
   */
  protected double computeGini(double[] dist, double total) {
    if (total == 0) {
      return 0;
    }
    // Gini = 1 - sum of squared class proportions
    double sumSquared = 0;
    for (int i = 0; i < dist.length; i++) {
      double proportion = dist[i] / total;
      sumSquared += proportion * proportion;
    }
    return 1 - sumSquared;
  }

  /**
   * Computes class probabilities for an instance using the decision tree.
   * 
   * A missing value for the split attribute is handled by descending into
   * both successors and combining their distributions, weighted by the
   * branch proportions.
   * 
   * @param instance the instance for which class probabilities are to be
   *          computed
   * @return the class probabilities for the given instance
   * @throws Exception if something goes wrong
   */
  @Override
  public double[] distributionForInstance(Instance instance) throws Exception {
    if (m_isLeaf) {
      return m_ClassProbs;
    }

    // value of split attribute is missing: blend both branches
    if (instance.isMissing(m_Attribute)) {
      double[] combined = new double[m_ClassProbs.length];
      for (int i = 0; i < m_Successors.length; i++) {
        double[] branchDist = m_Successors[i].distributionForInstance(instance);
        if (branchDist != null) {
          for (int j = 0; j < branchDist.length; j++) {
            combined[j] += m_Props[i] * branchDist[j];
          }
        }
      }
      return combined;
    }

    // pick the successor according to the split condition
    int branch;
    if (m_Attribute.isNominal()) {
      // left branch iff the attribute value is listed in the split string
      String token = "("
        + m_Attribute.value((int) instance.value(m_Attribute)) + ")";
      branch = (m_SplitString.indexOf(token) != -1) ? 0 : 1;
    } else {
      // numeric split attribute
      branch = (instance.value(m_Attribute) < m_SplitValue) ? 0 : 1;
    }
    return m_Successors[branch].distributionForInstance(instance);
  }

  /**
   * Turn this node into a leaf node, predicting the majority class.
   * 
   * @param data training data (supplies the class attribute)
   */
  protected void makeLeaf(Instances data) {
    m_isLeaf = true;
    m_Attribute = null;
    m_ClassAttribute = data.classAttribute();
    m_ClassValue = Utils.maxIndex(m_ClassProbs);
  }

  /**
   * Prints the decision tree using the protected toString method from below.
   * 
   * @return a textual description of the classifier
   */
  @Override
  public String toString() {
    if ((m_ClassProbs == null) && (m_Successors == null)) {
      return "CART Tree: No model built yet.";
    }

    StringBuilder buf = new StringBuilder("CART Decision Tree\n");
    buf.append(toString(0));
    buf.append("\n\n").append("Number of Leaf Nodes: ").append(numLeaves());
    buf.append("\n\n").append("Size of the Tree: ").append(numNodes());
    return buf.toString();
  }

  /**
   * Outputs the subtree rooted at this node at a certain level.
   * 
   * @param level the level at which the tree is to be printed
   * @return a textual description of the subtree
   */
  protected String toString(int level) {

    // StringBuilder instead of StringBuffer: no synchronization is needed
    // for a method-local buffer.
    StringBuilder text = new StringBuilder();
    // leaf node: print predicted class and (correct/wrong) training weights
    if (m_Attribute == null) {
      if (Utils.isMissingValue(m_ClassValue)) {
        text.append(": null");
      } else {
        // truncate the weights to two decimal places for display
        double correctNum = (int) (m_Distribution[Utils
          .maxIndex(m_Distribution)] * 100) / 100.0;
        double wrongNum = (int) ((Utils.sum(m_Distribution) - m_Distribution[Utils
          .maxIndex(m_Distribution)]) * 100) / 100.0;
        String str = "(" + correctNum + "/" + wrongNum + ")";
        text.append(": " + m_ClassAttribute.value((int) m_ClassValue) + str);
      }
    } else {
      // inner node: print the split condition for both branches
      for (int j = 0; j < 2; j++) {
        text.append("\n");
        for (int i = 0; i < level; i++) {
          text.append("|  ");
        }
        if (j == 0) {
          if (m_Attribute.isNumeric()) {
            text.append(m_Attribute.name() + " < " + m_SplitValue);
          } else {
            text.append(m_Attribute.name() + "=" + m_SplitString);
          }
        } else {
          if (m_Attribute.isNumeric()) {
            text.append(m_Attribute.name() + " >= " + m_SplitValue);
          } else {
            text.append(m_Attribute.name() + "!=" + m_SplitString);
          }
        }
        text.append(m_Successors[j].toString(level + 1));
      }
    }
    return text.toString();
  }

  /**
   * Compute size of the tree (number of nodes, leaves included).
   * 
   * @return size of the tree
   */
  public int numNodes() {
    int size = 1; // this node
    if (!m_isLeaf) {
      for (SimpleCart successor : m_Successors) {
        size += successor.numNodes();
      }
    }
    return size;
  }

  /**
   * Method to count the number of inner (non-leaf) nodes in the tree.
   * 
   * @return the number of inner nodes
   */
  public int numInnerNodes() {
    if (m_Attribute == null) {
      return 0; // leaf node
    }
    int count = 1;
    for (SimpleCart successor : m_Successors) {
      count += successor.numInnerNodes();
    }
    return count;
  }

  /**
   * Return a list of all inner nodes in the tree.
   * 
   * @return the list of all inner nodes
   */
  protected Vector<SimpleCart> getInnerNodes() {
    Vector<SimpleCart> result = new Vector<SimpleCart>();
    fillInnerNodes(result);
    return result;
  }

  /**
   * Fills a list with all inner nodes of the subtree rooted at this node.
   * 
   * @param nodeList the list to be filled
   */
  protected void fillInnerNodes(Vector<SimpleCart> nodeList) {
    if (m_isLeaf) {
      return; // leaves are not collected
    }
    nodeList.add(this);
    for (SimpleCart successor : m_Successors) {
      successor.fillInnerNodes(nodeList);
    }
  }

  /**
   * Compute number of leaf nodes.
   * 
   * @return number of leaf nodes
   */
  public int numLeaves() {
    if (m_isLeaf) {
      return 1;
    }
    int count = 0;
    for (SimpleCart successor : m_Successors) {
      count += successor.numLeaves();
    }
    return count;
  }

  /**
   * Returns an enumeration describing the available options.
   * 
   * @return an enumeration of all the available options.
   */
  @Override
  public Enumeration<Option> listOptions() {

    Vector<Option> options = new Vector<Option>(6);

    options.add(new Option(
      "\tThe minimal number of instances at the terminal nodes.\n"
        + "\t(default 2)", "M", 1, "-M <min no>"));
    options.add(new Option(
      "\tThe number of folds used in the minimal cost-complexity pruning.\n"
        + "\t(default 5)", "N", 1, "-N <num folds>"));
    options.add(new Option(
      "\tDon't use the minimal cost-complexity pruning.\n"
        + "\t(default yes).", "U", 0, "-U"));
    options.add(new Option(
      "\tDon't use the heuristic method for binary split.\n"
        + "\t(default true).", "H", 0, "-H"));
    options.add(new Option("\tUse 1 SE rule to make pruning decision.\n"
      + "\t(default no).", "A", 0, "-A"));
    options.add(new Option("\tPercentage of training data size (0-1].\n"
      + "\t(default 1).", "C", 1, "-C"));

    // append the options of the superclass
    options.addAll(Collections.list(super.listOptions()));

    return options.elements();
  }

  /**
   * Parses a given list of options.
   * <p/>
   * 
   * <!-- options-start --> Valid options are:
   * <p/>
   * 
   * <pre>
   * -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)
   * </pre>
   * 
   * <pre>
   * -M &lt;min no&gt;
   *  The minimal number of instances at the terminal nodes.
   *  (default 2)
   * </pre>
   * 
   * <pre>
   * -N &lt;num folds&gt;
   *  The number of folds used in the minimal cost-complexity pruning.
   *  (default 5)
   * </pre>
   * 
   * <pre>
   * -U
   *  Don't use the minimal cost-complexity pruning.
   *  (default yes).
   * </pre>
   * 
   * <pre>
   * -H
   *  Don't use the heuristic method for binary split.
   *  (default true).
   * </pre>
   * 
   * <pre>
   * -A
   *  Use 1 SE rule to make pruning decision.
   *  (default no).
   * </pre>
   * 
   * <pre>
   * -C
   *  Percentage of training data size (0-1].
   *  (default 1).
   * </pre>
   * 
   * <!-- options-end -->
   * 
   * @param options the list of options as an array of strings
   * @throws Exception if an options is not supported
   */
  @Override
  public void setOptions(String[] options) throws Exception {

    // numeric options fall back to their documented defaults when absent
    String optionString = Utils.getOption('M', options);
    setMinNumObj(optionString.length() != 0 ? Double.parseDouble(optionString)
      : 2);

    optionString = Utils.getOption('N', options);
    setNumFoldsPruning(optionString.length() != 0 ? Integer
      .parseInt(optionString) : 5);

    // flags -U and -H disable their features, hence the negation
    setUsePrune(!Utils.getFlag('U', options));
    setHeuristic(!Utils.getFlag('H', options));
    setUseOneSE(Utils.getFlag('A', options));

    optionString = Utils.getOption('C', options);
    setSizePer(optionString.length() != 0 ? Double.parseDouble(optionString)
      : 1);

    super.setOptions(options);

    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current settings of the classifier.
   * 
   * @return the current setting of the classifier
   */
  @Override
  public String[] getOptions() {

    Vector<String> options = new Vector<String>();

    options.add("-M");
    options.add(String.valueOf(getMinNumObj()));
    options.add("-N");
    options.add(String.valueOf(getNumFoldsPruning()));

    // flags are emitted only when they deviate from the default
    if (!getUsePrune()) {
      options.add("-U");
    }
    if (!getHeuristic()) {
      options.add("-H");
    }
    if (getUseOneSE()) {
      options.add("-A");
    }

    options.add("-C");
    options.add(String.valueOf(getSizePer()));

    Collections.addAll(options, super.getOptions());

    return options.toArray(new String[options.size()]);
  }

  /**
   * Return an enumeration of the measure names.
   * 
   * @return an enumeration of the measure names
   */
  @Override
  public Enumeration<String> enumerateMeasures() {
    Vector<String> measures = new Vector<String>();
    measures.add("measureTreeSize");
    return measures.elements();
  }

  /**
   * Return the size of the tree as an additional measure.
   * 
   * @return the number of nodes in the tree
   */
  public double measureTreeSize() {
    return numNodes();
  }

  /**
   * Returns the value of the named measure.
   * 
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  @Override
  public double getMeasure(String additionalMeasureName) {
    if (!additionalMeasureName.equalsIgnoreCase("measureTreeSize")) {
      throw new IllegalArgumentException(additionalMeasureName
        + " not supported (Cart pruning)");
    }
    return measureTreeSize();
  }

  /**
   * Returns the tip text for the minNumObj property.
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String minNumObjTipText() {
    String tip = "The minimal number of observations at the terminal nodes (default 2).";
    return tip;
  }

  /**
   * Set the minimal number of instances allowed at the terminal nodes.
   * 
   * @param value minimal number of instances at the terminal nodes
   */
  public void setMinNumObj(double value) {
    m_minNumObj = value;
  }

  /**
   * Get the minimal number of instances allowed at the terminal nodes.
   * 
   * @return minimal number of instances at the terminal nodes
   */
  public double getMinNumObj() {
    return m_minNumObj;
  }

  /**
   * Returns the tip text for the numFoldsPruning property.
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String numFoldsPruningTipText() {
    String tip = "The number of folds in the internal cross-validation (default 5).";
    return tip;
  }

  /**
   * Set the number of folds used in the internal cross-validation.
   * 
   * @param value number of folds in internal cross-validation.
   */
  public void setNumFoldsPruning(int value) {
    m_numFoldsPruning = value;
  }

  /**
   * Get the number of folds used in the internal cross-validation.
   * (Javadoc previously said "Set" by mistake.)
   * 
   * @return number of folds in internal cross-validation.
   */
  public int getNumFoldsPruning() {
    return m_numFoldsPruning;
  }

  /**
   * Return the tip text for the usePrune property.
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui.
   */
  public String usePruneTipText() {
    String tip = "Use minimal cost-complexity pruning (default yes).";
    return tip;
  }

  /**
   * Set whether minimal cost-complexity pruning is used.
   * 
   * @param value if use minimal cost-complexity pruning
   */
  public void setUsePrune(boolean value) {
    m_Prune = value;
  }

  /**
   * Get whether minimal cost-complexity pruning is used.
   * 
   * @return if use minimal cost-complexity pruning
   */
  public boolean getUsePrune() {
    return m_Prune;
  }

  /**
   * Returns the tip text for the heuristic property.
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui.
   */
  public String heuristicTipText() {
    String tip = "If heuristic search is used for binary split for nominal attributes "
      + "in multi-class problems (default yes).";
    return tip;
  }

  /**
   * Set whether heuristic search is used for nominal attributes in
   * multi-class problems.
   * 
   * @param value if use heuristic search for nominal attributes in multi-class
   *          problems
   */
  public void setHeuristic(boolean value) {
    m_Heuristic = value;
  }

  /**
   * Get whether heuristic search is used for nominal attributes in
   * multi-class problems.
   * 
   * @return if use heuristic search for nominal attributes in multi-class
   *         problems
   */
  public boolean getHeuristic() {
    return m_Heuristic;
  }

  /**
   * Returns the tip text for the useOneSE property.
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui.
   */
  public String useOneSETipText() {
    // fixed typo in the user-facing text: "decisoin" -> "decision"
    return "Use the 1SE rule to make pruning decision.";
  }

  /**
   * Set whether the 1SE rule is used to choose the final model.
   * 
   * @param value if use the 1SE rule to choose final model
   */
  public void setUseOneSE(boolean value) {
    m_UseOneSE = value;
  }

  /**
   * Get whether the 1SE rule is used to choose the final model.
   * 
   * @return if use the 1SE rule to choose final model
   */
  public boolean getUseOneSE() {
    return m_UseOneSE;
  }

  /**
   * Returns the tip text for the sizePer property.
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui.
   */
  public String sizePerTipText() {
    String tip = "The percentage of the training set size (0-1, 0 not included).";
    return tip;
  }

  /**
   * Set the fraction of the training set to use.
   * 
   * Values outside (0, 1] are reported on stderr and ignored, leaving the
   * current setting unchanged.
   * 
   * @param value training set size, as a fraction in (0, 1]
   */
  public void setSizePer(double value) {
    // guard clause: reject out-of-range values (original condition kept
    // verbatim so behavior, including NaN handling, is unchanged)
    if ((value <= 0) || (value > 1)) {
      System.err
        .println("The percentage of the training set size must be in range 0 to 1 "
          + "(0 not included) - ignored!");
      return;
    }
    m_SizePer = value;
  }

  /**
   * Get the fraction of the training set that is used.
   * 
   * @return training set size
   */
  public double getSizePer() {
    return m_SizePer;
  }

  /**
   * Returns the revision string extracted from the version-control keyword.
   * 
   * @return the revision
   */
  @Override
  public String getRevision() {
    // "$Revision$" is expanded by the version-control system
    return RevisionUtils.extract("$Revision$");
  }

  /**
   * Main method for running this classifier from the command line.
   * 
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    runClassifier(new SimpleCart(), args);
  }
}
